################################
## Script: 09_estimate_recall.R
## Purpose: This code runs the recall test for gpt4o, ngram,
## and full pipeline.
## Data In:
## 1) recall set (training):
## data/recall_master_updated_11_24_2024_with_sbert_cross_score_public.rds
## 2) recall set (holdout):
## data/recall_holdout_combined_with_sbert_cross_score_public.rds
## 3) ngram output:
## data/ngrams_edges.rds
## 4) relatio output:
## data/relatio_edges.rds
## 5) all articles:
## data/bioweapons_casestudy_5_20_2024_sbert_embeddings.json
## 6) list of claims and subjects
## all files with the form
## data/gpt4o_annotations/annotate_lists_[num].rds
## where "num" varies from 1-100
## 7) STM output:
## data/stm_pairs.rds
## 8) fine tuning training data:
## a) same claim training ids (all article ids used in same claim fine tuning)
## data/train_fine_tune_ids.rds
## b) same subject training ids  (all article ids used in same subject fine tuning)
## data/train_fine_tune_ids_subject.rds
## c) same claim training data (all article pairs used in same claim fine tuning)
## data/fine_tune_same_claim_training_data.rds
## d) same subject training data (all article pairs used in same subject fine tuning)
## data/fine_tune_same_subject_training_data.rds
## 9) Fine tune annotations
## data/gpt_annotations_finetune_full.rds
## 10) Zero shot annotations
## data/gpt_annotations_full.rds
## Data Out:
## 1) recall training data with data on whether pairs were recalled by each estimator:
## data/recall_master_updated_11_24_2024_with_all_recall_details_public.rds
## 2) recall holdout data with data on whether pairs were recalled by each estimator:
## data/recall_holdout_with_all_recall_estimates_public.rds
## Notes:

library(tidyverse) ## 2.0.0
library(openai) ## 0.4.1

##########################################
### Dependencies #########################
##########################################

## "not in" operator: the logical negation of %in%
`%ni%` <- Negate(`%in%`)

## recall training data set
## (this file already includes the sbert scores)
recall <- readRDS(
  "data/recall_master_updated_11_24_2024_with_sbert_cross_score_public.rds"
)

## recall holdout set, restricted to the rows with same_claim == "YES"
recall_holdout <- readRDS(
  "data/recall_holdout_combined_with_sbert_cross_score_public.rds"
) %>%
  filter(same_claim == "YES")

## ngram edges
## match_id is an order-independent pair key: the ego_id / alter_id
## values of a row are sorted and pasted together with "_"
## NOTE(review): apply() converts the data frame to a matrix before
## iterating, so the ids are handled as strings whenever any column is
## non-numeric -- confirm the ids are stored as character
ngrams <- readRDS("data/ngrams_edges.rds")
mat <- which(colnames(ngrams) %in% c("ego_id", "alter_id"))
ngrams$match_id <- apply(ngrams, 1, function(edge) {
  pair <- edge[mat]
  paste0(pair[order(pair)], collapse = "_")
})


## relatio edges, with the same pair key
relatio <- readRDS("data/relatio_edges.rds")
mat <- which(colnames(relatio) %in% c("ego_id", "alter_id"))
relatio$match_id <- apply(relatio, 1, function(edge) {
  pair <- edge[mat]
  paste0(pair[order(pair)], collapse = "_")
})


## total articles
total_art <- jsonlite::stream_in(file("data/bioweapons_casestudy_5_20_2024_sbert_embeddings.json"))


## list of claims and subjects for recall:
## keep the files whose name contains "list" and recover the chunk
## number from the "annotate_lists_[num].rds" file name
list_annotations <- list.files("data/gpt4o_annotations/")
list_annotations <- list_annotations[grepl("list", list_annotations)]
list_annotations_num <- gsub("annotate_lists_|\\.rds", "", list_annotations)
list_annotations_num <- as.numeric(list_annotations_num)

list_annotations <- paste0("data/gpt4o_annotations/",
                           list_annotations)

## read in lists and merge in data
## rows are dealt round-robin into 100 chunks (row 1 -> chunk 1, ...,
## row 100 -> chunk 100, row 101 -> chunk 1, ...). This is the same
## grouping split(total_art, 1:100) produces by recycling, but spelled
## out so no recycling warning is raised when the number of rows is not
## a multiple of 100.
## NOTE(review): presumably this mirrors how the annotation batches were
## built -- confirm against the annotation script
n_chunks <- 100L
chunk_id <- factor(rep_len(seq_len(n_chunks), nrow(total_art)),
                   levels = seq_len(n_chunks))
total_art <- split(total_art, chunk_id)

for (i in seq_len(n_chunks)) {
  ## exactly one annotation file must exist for each chunk
  index <- which(list_annotations_num == i)
  if (length(index) != 1L) {
    stop("expected exactly one annotation file for chunk ", i,
         ", found ", length(index), call. = FALSE)
  }
  file_toread <- readRDS(list_annotations[index])

  ## first list element is used as the subject responses,
  ## second list element as the claim responses
  subject <- unlist(lapply(file_toread[[1]], function(x) {
    x$choices$message.content
  }))
  claim <- unlist(lapply(file_toread[[2]], function(x) {
    x$choices$message.content
  }))

  ## unlist() drops empty responses, and assigning a shorter vector to a
  ## data frame column is silently recycled when its length divides the
  ## number of rows; fail loudly instead of misaligning annotations
  n_articles <- nrow(total_art[[i]])
  if (length(subject) != n_articles || length(claim) != n_articles) {
    stop("chunk ", i, ": ", n_articles, " articles but ",
         length(subject), " subject and ", length(claim),
         " claim annotations", call. = FALSE)
  }

  total_art[[i]]$subject_list <- subject
  total_art[[i]]$claim_list <- claim
}

total_art <- bind_rows(total_art)

## merge the claim lists, subject lists and article dates into the
## recall sets. "ego" values are looked up through focal_article and
## "alter" values through article_id; both are matched against
## total_art$article_id, so unmatched articles get NA
ego_row <- match(recall$focal_article, total_art$article_id)
alter_row <- match(recall$article_id, total_art$article_id)
holdout_ego_row <- match(recall_holdout$focal_article, total_art$article_id)
holdout_alter_row <- match(recall_holdout$article_id, total_art$article_id)

## recall training set: claims, subjects and dates
recall$ego_claim <- total_art$claim_list[ego_row]
recall$alter_claim <- total_art$claim_list[alter_row]
recall$ego_subject <- total_art$subject_list[ego_row]
recall$alter_subject <- total_art$subject_list[alter_row]
recall$ego_date <- as.Date(total_art$final_date[ego_row])
recall$alter_date <- as.Date(total_art$final_date[alter_row])

## recall holdout set: dates only
recall_holdout$ego_date <- as.Date(total_art$final_date[holdout_ego_row])
recall_holdout$alter_date <- as.Date(total_art$final_date[holdout_alter_row])

## zero shot annotations
## match_id: the row's ego_id / alter_id values, sorted and pasted
## together with "_", so the key does not depend on pair direction
matches <- readRDS("data/gpt_annotations_full.rds")
mat <- which(colnames(matches) %in% c("ego_id", "alter_id"))
matches$match_id <- apply(matches, 1, function(edge) {
  pair <- edge[mat]
  paste0(pair[order(pair)], collapse = "_")
})
table(matches$gpt_annotation) ## YES: 18,138

## STM pairs, with the same pair key
stm_pairs <- readRDS("data/stm_pairs.rds")
mat <- which(colnames(stm_pairs) %in% c("ego_id", "alter_id"))
stm_pairs$match_id <- apply(stm_pairs, 1, function(edge) {
  pair <- edge[mat]
  paste0(pair[order(pair)], collapse = "_")
})
nrow(stm_pairs) ## 34,210

## gpt4o fine tuning: article ids used in training
## (same claim and same subject)
fine_tune_train_ids <- readRDS("data/train_fine_tune_ids.rds")
fine_tune_train_ids_same_subject <- readRDS("data/train_fine_tune_ids_subject.rds")

## gpt4o fine tuning: the actual article pairs used in training
fine_tune_same_claim_data <- readRDS("data/fine_tune_same_claim_training_data.rds")
fine_tune_same_subject_data <- readRDS("data/fine_tune_same_subject_training_data.rds")

## fine tune labels: build the pair key, then keep only the pairs
## labelled "YES"
finetune <- readRDS("data/gpt_annotations_finetune_full.rds")
mat <- which(colnames(finetune) %in% c("ego_id", "alter_id"))
finetune$match_id <- apply(finetune, 1, function(edge) {
  pair <- edge[mat]
  paste0(pair[order(pair)], collapse = "_")
})
finetune <- filter(finetune, gpt4o_finetune == "YES")
nrow(finetune) ## 4,204


#######################################
### Calculate Recall #################
#######################################

## checking for data leakage from the fine tuning:

## 1) pair level: was the recall pair itself used in the fine tuning
## (same claim or same subject training pairs)?
fine_tune_pair_ids <- c(fine_tune_same_claim_data$match_id,
                        fine_tune_same_subject_data$match_id)

recall$train_gpt4o_finetune_pairs <- ifelse(
  recall$match_id %in% fine_tune_pair_ids,
  "training data", "not training data"
)
recall_holdout$train_gpt4o_finetune_pairs <- ifelse(
  recall_holdout$match_id %in% fine_tune_pair_ids,
  "training data", "not training data"
)

table(recall$train_gpt4o_finetune_pairs) ## 0
table(recall_holdout$train_gpt4o_finetune_pairs) ## 0

## 2) article level: was either article of the recall pair used at all
## in the fine tuning? NOTE: this doesn't mean the pair was included;
## it is the more stringent definition that an article appeared in any
## training pair. We report this as a robustness check.
## The fine tuning was done long after we constructed the recall set,
## so we couldn't eliminate these cases.
fine_tune_article_ids <- c(fine_tune_train_ids,
                           fine_tune_train_ids_same_subject)

recall$train_gpt4o_finetune_ids <- ifelse(
  recall$article_id %in% fine_tune_article_ids |
    recall$focal_article %in% fine_tune_article_ids,
  "training data", "not training data"
)
recall_holdout$train_gpt4o_finetune_ids <- ifelse(
  recall_holdout$article_id %in% fine_tune_article_ids |
    recall_holdout$focal_article %in% fine_tune_article_ids,
  "training data", "not training data"
)

table(recall$train_gpt4o_finetune_ids) ## 44 not included in training data in any fashion
table(recall_holdout$train_gpt4o_finetune_ids) ## 26 not included in training data in any fashion



## total predicted pairs, by estimator
## sbert llm finetune
nrow(finetune) ## 4,204
## sbert llm zero shot
sum(matches$gpt_annotation == "YES") ## 18,138
## stm
nrow(stm_pairs)  ## 34,210
## ngram
table(ngrams$stratum)
## relatio
table(relatio$stratum)


## creating the recall indicators: TRUE when the recall pair's
## match_id is among the pairs predicted by the estimator

## SBERT-LLM zero shot: pairs annotated "YES"
zero_shot_yes_ids <- matches$match_id[matches$gpt_annotation == "YES"]
recall$llm_zero_shot <- recall$match_id %in% zero_shot_yes_ids
recall_holdout$llm_zero_shot <- recall_holdout$match_id %in% zero_shot_yes_ids

## SBERT LLM Fine Tuning
recall$llm_fine_tune <- recall$match_id %in% finetune$match_id
recall_holdout$llm_fine_tune <- recall_holdout$match_id %in% finetune$match_id

## SBERT LLM Fine Tuning with Stringent Exclusion of Training Data:
## NA when the pair is flagged "training data" at the article level,
## otherwise the same indicator as above
recall$llm_fine_tune_stringent <- ifelse(
  recall$train_gpt4o_finetune_ids == "training data",
  NA,
  recall$match_id %in% finetune$match_id
)
recall_holdout$llm_fine_tune_stringent <- ifelse(
  recall_holdout$train_gpt4o_finetune_ids == "training data",
  NA,
  recall_holdout$match_id %in% finetune$match_id
)

## NGRAMS

## a) .2 threshold: every edge in the ngram data counts
min(ngrams$sim) ## that was minimum (5-word gram)
recall$ngram_2 <- recall$match_id %in% ngrams$match_id
recall_holdout$ngram_2 <- recall_holdout$match_id %in% ngrams$match_id

## b) .4 threshold
ngram_ids_4 <- ngrams$match_id[ngrams$sim >= .4]
recall$ngram_4 <- recall$match_id %in% ngram_ids_4
recall_holdout$ngram_4 <- recall_holdout$match_id %in% ngram_ids_4

## c) .6 threshold
ngram_ids_6 <- ngrams$match_id[ngrams$sim >= .6]
recall$ngram_6 <- recall$match_id %in% ngram_ids_6
recall_holdout$ngram_6 <- recall_holdout$match_id %in% ngram_ids_6

## STM: pair appears among the STM pairs
recall$STM_cluster <- recall$match_id %in% stm_pairs$match_id
recall_holdout$STM_cluster <- recall_holdout$match_id %in% stm_pairs$match_id


## RELATIO
## thresholds above the minimum are applied through the stratum labels

## a) .1: every edge in the relatio data counts
min(relatio$sim) ## that was the minimum cutoff
recall$relatio_1 <- recall$match_id %in% relatio$match_id
recall_holdout$relatio_1 <- recall_holdout$match_id %in% relatio$match_id

## stratum membership, one logical vector per label
in_2_to_4 <- relatio$stratum == ".2 < .4"
in_4_to_6 <- relatio$stratum == ".4 < .6"
in_6_plus <- relatio$stratum == ".6+"

## b) .2: strata ".2 < .4", ".4 < .6" and ".6+"
relatio_ids_2 <- relatio$match_id[in_2_to_4 | in_4_to_6 | in_6_plus]
recall$relatio_2 <- recall$match_id %in% relatio_ids_2
recall_holdout$relatio_2 <- recall_holdout$match_id %in% relatio_ids_2

## c) .4: strata ".4 < .6" and ".6+"
relatio_ids_4 <- relatio$match_id[in_4_to_6 | in_6_plus]
recall$relatio_4 <- recall$match_id %in% relatio_ids_4
recall_holdout$relatio_4 <- recall_holdout$match_id %in% relatio_ids_4

## d) .6: stratum ".6+" only
relatio_ids_6 <- relatio$match_id[in_6_plus]
recall$relatio_6 <- recall$match_id %in% relatio_ids_6
recall_holdout$relatio_6 <- recall_holdout$match_id %in% relatio_ids_6


## saving recall data
## we use these in precision coding file
## to create combined table:

#saveRDS(recall, "data/recall_master_updated_11_24_2024_with_all_recall_details_public.rds")
#saveRDS(recall_holdout, "data/recall_holdout_with_all_recall_estimates_public.rds")

